package crypto

import (
	"crypto/rand"
	"io"
	"math/big"
	mathrand "math/rand"
	"testing"

	"github.com/stretchr/testify/require"
)

// genGAB computes the client-side values of a Diffie-Hellman key exchange.
//
// See https://en.wikipedia.org/wiki/Diffie%E2%80%93Hellman_key_exchange
//
// A secret exponent b is drawn from [0, 2^RSAKeyBits) with crypto/rand.Int,
// reading entropy from randSource. The function then returns b together with:
//
//	gB  = g^b  mod dhPrime  (g_b)
//	gAB = gA^b mod dhPrime  (g_ab)
//
// where dhPrime is dh_prime and gA is g_a. If reading from randSource fails,
// that error is returned and every *big.Int result is nil.
//
// Outside of deterministic tests, randSource should be a cryptographically
// secure reader such as crypto/rand.Reader.
func genGAB(dhPrime, g, gA *big.Int, randSource io.Reader) (b, gB, gAB *big.Int, err error) {
	// Exclusive upper bound for the secret exponent: a single set bit at
	// position RSAKeyBits, i.e. 2^RSAKeyBits.
	limit := new(big.Int).SetBit(new(big.Int), RSAKeyBits, 1)

	// 6. Random number b is computed:
	secret, randErr := rand.Int(randSource, limit)
	if randErr != nil {
		return nil, nil, nil, randErr
	}

	public, shared := gab(dhPrime, secret, g, gA)
	return secret, public, shared, nil
}

// gab raises both g and gA to the exponent b modulo dhPrime and returns
// gB = g^b mod dhPrime and gAB = gA^b mod dhPrime.
//
// None of the arguments are modified; each result is a freshly allocated
// big.Int.
func gab(dhPrime, b, g, gA *big.Int) (gB, gAB *big.Int) {
	modPow := func(base *big.Int) *big.Int {
		return new(big.Int).Exp(base, b, dhPrime)
	}
	return modPow(g), modPow(gA)
}

// TestGAB verifies computation of g_b = g^b mod dh_prime and
// g_ab = g_a^b mod dh_prime.
//
// The "Static" subtest uses the sample values from the MTProto auth-key
// walkthrough linked below. The "Random" subtest pins the output of genGAB for
// a seeded, deterministic random source.
func TestGAB(t *testing.T) {
	// https://core.telegram.org/mtproto/samples-auth_key#server-dh-inner-data-decomposition-using-the-following-formula
	g := big.NewInt(2)
	gA, ok := big.NewInt(0).SetString("262AABA621CC4DF587DC94CF8252258C"+
		"0B9337DFB47545A49CDD5C9B8EAE7236"+
		"C6CADC40B24E88590F1CC2CC762EBF1C"+
		"F11DCC0B393CAAD6CEE4EE5848001C73"+
		"ACBB1D127E4CB93072AA3D1C8151B6FB"+
		"6AA6124B7CD782EAF981BDCFCE9D7A00"+
		"E423BD9D194E8AF78EF6501F415522E4"+
		"4522281C79D906DDB79C72E9C63D83FB"+
		"2A940FF779DFB5F2FD786FB4AD71C9F0"+
		"8CF48758E534E9815F634F1E3A80A5E1"+
		"C2AF210C5AB762755AD4B2126DFA61A7"+
		"7FA9DA967D65DFD0AFB5CDF26C4D4E1A"+
		"88B180F4E0D0B45BA1484F95CB2712B5"+
		"0BF3F5968D9D55C99C0FB9FB67BFF56D"+
		"7D4481B634514FBA3488C4CDA2FC0659"+
		"990E8E868B28632875A9AA703BCDCE8F", 16)
	require.True(t, ok, "invalid g_a hex constant")
	dhPrimeStr := "C71CAEB9C6B1C9048E6C522F70F13F73" +
		"980D40238E3E21C14934D037563D930F" +
		"48198A0AA7C14058229493D22530F4DB" +
		"FA336F6E0AC925139543AED44CCE7C37" +
		"20FD51F69458705AC68CD4FE6B6B13AB" +
		"DC9746512969328454F18FAF8C595F64" +
		"2477FE96BB2A941D5BCD1D4AC8CC4988" +
		"0708FA9B378E3C4F3A9060BEE67CF9A4" +
		"A4A695811051907E162753B56B0F6B41" +
		"0DBA74D8A84B2A14B3144E0EF1284754" +
		"FD17ED950D5965B4B9DD46582DB1178D" +
		"169C6BC465B0D6FF9CA3928FEF5B9AE4" +
		"E418FC15E83EBEA0F87FA9FF5EED7005" +
		"0DED2849F47BF959D956850CE929851F" +
		"0D8115F635B105EE2E4E15D04B2454BF" +
		"6F4FADF034B10403119CD8E3B92FCC5B"
	dhPrime, ok := big.NewInt(0).SetString(dhPrimeStr, 16)
	require.True(t, ok, "invalid dh_prime hex constant")

	t.Run("Static", func(t *testing.T) {
		b, ok := big.NewInt(0).SetString("6F620AFA575C9233EB4C014110A7BCAF49464F798A18A0981FEA1E05E8DA"+
			"67D9681E0FD6DF0EDF0272AE3492451A84502F2EFC0DA18741A5FB80BD82296919A70FAA6D07CBBBCA2037EA7D3E327B61D"+
			"585ED3373EE0553A91CBD29B01FA9A89D479CA53D57BDE3A76FBD922A923A0A38B922C1D0701F53FF52D7EA9217080163A64901"+
			"E766EB6A0F20BC391B64B9D1DD2CD13A7D0C946A3A7DF8CEC9E2236446F646C42CFE2B60A2A8D776E56C8D7519B08B88ED0970E"+
			"10D12A8C9E355D765F2B7BBB7B4CA9360083435523CB0D57D2B106FD14F94B4EEE79D8AC131CA56AD389C84FE279716F8124A54"+
			"3337FB9EA3D988EC5FA63D90A4BA3970E7A39E5C0DE5", 16)
		require.True(t, ok, "invalid b hex constant")
		gB, gAB := gab(dhPrime, b, g, gA)

		require.NoError(t, CheckDHParams(dhPrime, g, gA, gB))
		require.NotNil(t, b)
		require.NotNil(t, gAB)

		gBVector, ok := big.NewInt(0).SetString("73700E7BFC7AEEC828EB8E0DCC04D09A"+
			"0DD56A1B4B35F72F0B55FCE7DB7EBB72"+
			"D7C33C5D4AA59E1C74D09B01AE536B31"+
			"8CFED436AFDB15FE9EB4C70D7F0CB14E"+
			"46DBBDE9053A64304361EB358A9BB32E"+
			"9D5C2843FE87248B89C3F066A7D5876D"+
			"61657ACC52B0D81CD683B2A0FA93E8AD"+
			"AB20377877F3BC3369BBF57B10F5B589"+
			"E65A9C27490F30A0C70FFCFD3453F5B3"+
			"79C1B9727A573CFFDCA8D23C721B135B"+
			"92E529B1CDD2F7ABD4F34DAC4BE1EEAF"+
			"60993DDE8ED45890E4F47C26F2C0B2E0"+
			"37BB502739C8824F2A99E2B1E7E41658"+
			"3417CC79A8807A4BDAC6A5E9805D4F61"+
			"86C37D66F6988C9F9C752896F3D34D25"+
			"529263FAF2670A09B2A59CE35264511F", 16)
		require.True(t, ok, "invalid g_b hex constant")
		require.Zero(t, gBVector.Cmp(gB), "g_b mismatch: want %x, got %x", gBVector, gB)
	})
	t.Run("Random", func(t *testing.T) {
		// A seeded math/rand source makes b (and therefore g_b and g_ab)
		// reproducible so the values below can be pinned. It is deliberately
		// NOT a cryptographically secure source.
		b, gB, gAB, err := genGAB(dhPrime, g, gA, mathrand.New(mathrand.NewSource(239)))
		require.NoError(t, err)
		require.NoError(t, CheckDHParams(dhPrime, g, gA, gB))

		bExpected, ok := big.NewInt(0).SetString("f6850b9cff75091df2d4d02d068868e95c0ec92e07c0dcee572705fcc"+
			"3ac766370b6a8d4fa17bf628247135c962156aa4ab6e173b699e2fa6cb607a0b3d35205ffd36635dd37572d132f7f16952c"+
			"0feb626be13e5165fae18d10cee45ecc9c83883a039903d354e46492d011953ddd8b619dbd4ad6c8fc9f3fab2d0cbcfacea"+
			"d28fc18fd7ee1310e9c0f98066204648ba1296e82c691cf87eaa9c2ef0d2775fb4c5d41432d77028c5640ae91e2d8b0033a7"+
			"dd4cee0aa87b57798b5afecee8c35a08ca9adcd3a753b3937f21b5363938e6efa5cca7ae45710c874a9b180a634eb2d6f7f7"+
			"7cd3e93418a5badfe2cea621a5f3e85f9bea16273b9dbf1924a6eb845", 16)
		require.True(t, ok, "invalid expected b hex constant")
		require.Zero(t, bExpected.Cmp(b), "b mismatch: want %x, got %x", bExpected, b)

		gABExpected, ok := big.NewInt(0).SetString("35cc5280773903e351c5991b3dbbae8677cf73ea16fe08d2b52c88bac"+
			"5aa883f16d30013c55379419aac65487dc3e97afada383afcb41c424e393cf1acc7c347202ba74a94cf049c26639016f1dc60"+
			"6c0d29e31bf147f13528e34183a9ad26a4c6537cff68bd0e82bff69b9fbe118bba3732581fa9f372ef5228a8529fc5f4ed8ff"+
			"2fe0c5877e42ab45efaa9da36f1d2c6ffbd4c8d32f34a20579405ddc867b6da09f52499c20ac7def55938bfbcbe0b0047e5c1"+
			"d241cc83c65bab84f4997a5025b447d15e57f9a7b13ed9397de0d236ce50da1b55c51d91ed1214de5cad9654129b1ada3caf8"+
			"8028510209a30aa37ab81fdc7eb17160ba54c772a888fcd2549", 16)
		require.True(t, ok, "invalid expected g_ab hex constant")
		require.Zero(t, gABExpected.Cmp(gAB), "g_ab mismatch: want %x, got %x", gABExpected, gAB)
	})
}
